
# %%
import copy
import itertools
import logging
import os
import random
import re
import time
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor
from typing import (
    TypeVar, Type, Union, Optional, Any,
    List, Dict, Tuple, Callable, NamedTuple, Iterator
)

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, Tensor, optim
from torch.autograd import Variable

from utils import Args, D, timeit

# Module-level logger; the root logger is configured at DEBUG level with a
# timestamp / file / line-number / level prefix on every message.
logger = logging.getLogger(__name__)
logging.basicConfig(
    format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
    level=logging.DEBUG,
)


def inf_loop() -> Iterator[None]:
    """
    Endless generator that yields ``None`` forever.

    Iterating over it only ends when the consumer stops (e.g. via ``break``).

    :return: an iterator that never terminates
    """
    while True:
        yield None


class Sample(NamedTuple):
    """
    A single sample of the dataset, made of its data and its label.

    ``tag`` records where the sample came from, e.g. sample 49-15 of the
    training set is tagged ``train-49-15``.
    """
    data: List[Tuple[float, int]]  # every transfer record of the data file, as [Time, UpOrDown]
    label: int  # class label, 0-49
    tag: str  # origin marker

    def to_tensor_sample(self) -> Tuple[Tensor, int]:
        """
        Convert the sample to a ``(Tensor, int)`` pair.

        The data part is laid out as (Channel, L): the time values and the
        up/down stream flags each occupy one channel.

        :return: the transposed data tensor and the label
        """
        return Tensor(self.data).t(), self.label


class RawDataSet(NamedTuple):
    """
    Raw dataset: the samples of the training split and of the test split.
    """
    train: List[Sample]  # training samples
    test: List[Sample]  # test samples


def read_data(data_dir: str, max_workers: int = 12,
              num_train: Optional[int] = None, num_test: Optional[int] = None) -> RawDataSet:
    """
    Read the raw dataset from disk and load all of it into memory.

    Training samples are read from ``<data_dir>/defence`` and test samples
    from ``<data_dir>/undefence``. Every data file holds one record per line:
    a time value and an integer stream type separated by whitespace. The
    label of a file is the first number of the ``<label>-<index>`` part of
    its name.

    Files whose name does not match that pattern are skipped. Files that
    cannot be read or parsed, or that contain no record, are skipped with a
    logged warning instead of aborting the whole load.

    :param data_dir: dataset root directory
    :param max_workers: number of threads used to read the files
    :param num_train: number of training files to read, None means all
    :param num_test: number of test files to read, None means all
    :return: the loaded dataset
    """
    log = logging.getLogger(__name__)
    train_dir = os.path.join(data_dir, "defence")
    test_dir = os.path.join(data_dir, "undefence")
    train_files = os.listdir(train_dir)
    test_files = os.listdir(test_dir)
    file_name_pattern = re.compile(r"(\d+?)\-(\d+?)")

    def fn_train_tag(s):
        return f"train-{s}"

    def fn_test_tag(s):
        return f"test-{s}"

    def parse_file(file_path: str) -> List[Tuple[float, int]]:
        """Parse one data file into a list of (time, stream_type) records."""
        records = []
        with open(file_path, "r") as fr:
            for line in fr:
                # str.split() collapses any run of whitespace, so lines using
                # several spaces/tabs as separator are parsed correctly.
                fields = line.split()
                if not fields:
                    # Blank line: nothing to parse.
                    continue
                # Raises ValueError when the line does not hold exactly two fields.
                _time_str, _stream_type_str = fields
                records.append((float(_time_str), int(_stream_type_str)))
        return records

    def build_raw_dataset(_dir: str, files: List[str], fn_tag: Callable[[str], str],
                          num_samples: Optional[int] = None) -> List[Sample]:
        def load_sample(file_name: str) -> Optional[Sample]:
            """Load one file; return None when it has to be skipped."""
            matched = file_name_pattern.search(file_name)
            if matched is None:
                # File name does not follow the expected pattern: skip it.
                return None
            file_path = os.path.join(_dir, file_name)
            try:
                records = parse_file(file_path)
            except (OSError, ValueError) as err:
                log.warning("skipping %s: %s", file_path, err)
                return None
            if not records:
                log.warning("skipping %s: file contains no records", file_path)
                return None
            return Sample(records, int(matched.group(1)), fn_tag(file_name))

        if num_samples is not None:
            # Only the first num_samples entries of the listing are read.
            files = files[:max(num_samples, 0)]
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # map() keeps the results in the order of `files`.
            loaded = list(executor.map(load_sample, files))
        # Drop the files that were skipped.
        return [sample for sample in loaded if sample is not None]

    train_dataset = build_raw_dataset(
        train_dir, train_files, fn_train_tag, num_train)
    test_dataset = build_raw_dataset(
        test_dir, test_files, fn_test_tag, num_test)
    return RawDataSet(train_dataset, test_dataset)


def calc_accuracy(fn_predict: Callable[[Sample], int], sample_ls: List[Sample]) -> float:
    """
    Compute the top-1 accuracy of a predictor over a list of samples.

    :param fn_predict: the prediction function under test; maps a sample to a predicted label
    :param sample_ls: the evaluation dataset (list of samples)
    :return: fraction of samples whose prediction equals ``sample.label``;
             0.0 when ``sample_ls`` is empty
    """
    total = 0
    right = 0
    for sample in sample_ls:
        total += 1
        if fn_predict(sample) == sample.label:
            right += 1
    if total == 0:
        # No samples: avoid ZeroDivisionError and report zero accuracy.
        return 0.0
    return right / total

